clear

%% ------------------------------------------------------------------------
%  Estimation settings
%  ------------------------------------------------------------------------

% Label used for every output file written at the end of this script.
estimation_name = '2t_abc_rn';

% Number of latent preference types in the finite mixture.
n_types = 2;

% Curvature of utility:
%   0 -> CRRA utility with r fixed at 1
%   1 -> risk neutrality
%   2 -> r is estimated as an additional free parameter
risk_neutral = 1;

% Restrictions on the preference parameters (1 = fix the parameter at zero).
alpha_zero = 0;
beta_zero  = 0;
kappa_zero = 0;
delta_zero = 1;
gamma_zero = 1;

% EM-algorithm controls.
n_steps      = 24;    % number of random starting points
n_iterations = 50;    % maximum number of EM iterations per starting point
ll_diff_stop = .01;   % stop once the log-likelihood gains less than this
get_se       = 1;     % 1 = bootstrap standard errors (long computing time)
nr           = 60;    % number of bootstrap replications

%import choices from experiment
% Each subject is expected to occupy no_trials consecutive rows of the file.
filename='mnl_data_core.csv';
dt=readmatrix(filename);

[m_dt,n_dt] = size(dt);

% n_params: parameters stored per type (7 when rho is estimated, else 6).
% n_est_params: estimated quantities per type used in the ICL penalty; each
% preference parameter not fixed at zero, plus 2 (plus 3 when rho is free).
% NOTE(review): the constant presumably covers lambda (and rho) and the
% type's mixing weight, with the sum-to-one restriction subtracted where the
% ICL is computed -- confirm.
if risk_neutral==2
    n_params = 7;
    n_est_params = (1-alpha_zero)+(1-beta_zero)+(1-kappa_zero)+(1-delta_zero)+(1-gamma_zero)+3;
else
    n_params = 6;
    n_est_params = (1-alpha_zero)+(1-beta_zero)+(1-kappa_zero)+(1-delta_zero)+(1-gamma_zero)+2;
end

no_trials=18;

% Use the row count explicitly: length() returns the largest dimension of dt,
% which is the number of rows only while the file is longer than it is wide.
if mod(m_dt,no_trials) ~= 0
    error('Data has %d rows, which is not a multiple of no_trials = %d.', m_dt, no_trials);
end
n = m_dt/no_trials; % n is the number of subjects in the data

%% Random starting values for the EM runs ---------------------------------
% Uniform draws on [min, max] for each preference parameter.
a_min = -0.6; a_max = 1.0;   % alpha
b_min = -1.3; b_max = 1.1;   % beta
c_min = -0.3; c_max = 0.8;   % kappa
d_min = -0.5; d_max = 0.5;   % delta
e_min = -0.5; e_max = 0.5;   % gamma
l_min = 0.01; l_max = 14;    % lambda
r_min = -0.1; r_max = 1.6;   % rho (only drawn when risk_neutral == 2)

pm0_initial = zeros(n_steps,n_types*n_params);

% Within a type the columns are ordered alpha, beta, kappa, delta, gamma,
% lambda[, rho]. Columns are filled from the last parameter backwards; keep
% this order, because under rng('default') it fixes which uniform draws end
% up in which parameter.
draw_lo = [a_min b_min c_min d_min e_min l_min r_min];
draw_hi = [a_max b_max c_max d_max e_max l_max r_max];

rng('default');
for type = 1:n_types
    first_col = n_params*(type-1);
    for k = n_params:-1:1
        pm0_initial(:,first_col+k) = rand(n_steps,1)*(draw_hi(k)-draw_lo(k))+draw_lo(k);
    end
end
clear draw_lo draw_hi first_col k
            
%% Linear constraints for fmincon -----------------------------------------
% The inequality system A_unc*x <= b_unc is all zeros (inactive). Every
% parameter that is fixed at zero gets a unit diagonal entry in Aeq_unc, so
% Aeq_unc*x = beq_unc forces it to 0. Its starting values are also set to 0.
A_unc=zeros(n_params*n_types,n_params*n_types);
b_unc=zeros(1,n_params*n_types);
Aeq_unc=zeros(n_params*n_types,n_params*n_types);
beq_unc=zeros(1,n_params*n_types);

% Positions 1..5 within a type are alpha, beta, kappa, delta, gamma.
fixed_at_zero = [alpha_zero beta_zero kappa_zero delta_zero gamma_zero] == 1;

for type2 = 1:n_types
    for k = find(fixed_at_zero)
        col = n_params*(type2-1) + k;
        Aeq_unc(col,col) = 1;
        pm0_initial(:,col) = 0;
    end
end
clear fixed_at_zero k col

%% Containers for the results of every starting point ---------------------
% Row s (page s for the taus) holds the best EM iteration of start s.
% Columns of fit_subresults: log-likelihood, ICL, entropy, Q-value.
fit_subresults  = zeros(n_steps, 4);
pm_subresults   = zeros(n_steps, n_types*n_params);
pm0_subresults  = zeros(n_steps, n_types*n_params);
pi_subresults   = zeros(n_steps, n_types);
tau_subresults  = zeros(n, n_types, n_steps);

% Problem dimensions handed to lolik_mix / lolik_mix_estep.
setts = [n_types, n, no_trials];

% start EM algorithm
% One independent EM run per starting point s.
% E-step: posterior type probabilities (taus) from the per-subject, per-type
% log-likelihoods returned by lolik_mix_estep.
% M-step: mixing weights are the column means of taus; the preference
% parameters are updated by fmincon on lolik_mix under the linear constraints
% built above.
% A run stops when the log-likelihood gain drops below ll_diff_stop (or turns
% negative) or after n_iterations. Its best iteration (highest
% log-likelihood) is stored.
parfor s = 1:n_steps

    pis = zeros(n_iterations,n_types);
    pms = zeros(n_iterations,n_types*n_params);
    fits = zeros(n_iterations,4);
    taus = zeros(n,n_types,n_iterations);
    ind_lls = zeros(n,n_types,n_iterations);

    iter = 0;
    ll_diff = 100;

    while (ll_diff>ll_diff_stop && iter < n_iterations)

        iter = iter+1;

        if iter==1
            pm0 = pm0_initial(s,:);                 % starting values for preference parameters
            pi0 = (1/n_types)*ones(1,n_types);      % equal initial mixing weights (row vector)
            pm0_subresults(s,:)=pm0;
            ind_ll0=lolik_mix_estep(pm0,dt,setts,risk_neutral);
        else
            pm0 = pms(iter-1,:);                    % previous M-step estimates
            pi0 = pis(iter-1,:);                    % previous mixing weights
            ind_ll0=ind_lls(:,:,iter-1);
        end

        % E-step in log space: tau(i,g) = pi(g)*L(i,g) / sum_h pi(h)*L(i,h).
        % Subtracting each row's maximum before exp() keeps the ratio defined
        % when exp(log-likelihood) would underflow to zero for every type.
        log_w = log(pi0) + ind_ll0;
        row_max = max(log_w,[],2);
        row_max(~isfinite(row_max)) = 0;
        w = exp(log_w - row_max);
        taus(:,:,iter) = w ./ sum(w,2);

        taus1=taus(:,:,iter);
        pi1=mean(taus1);
        pis(iter,:)=pi1;

        try
            f_unc=@(pm2)lolik_mix(pm2,dt,setts,taus1,pi1,risk_neutral);
            [pm2,fval] = fmincon(f_unc,pm0,A_unc,b_unc,Aeq_unc,beq_unc);
            il=lolik_mix_estep(pm2,dt,setts,risk_neutral);
        catch ME
            % Keep the run going but report why it failed. The sentinel
            % values below give this iteration a very low log-likelihood.
            warning('Problem using function: %s', ME.message);
            fval = 99999;
            pm2 = 99*ones(1,n_params*n_types);
            il = -99*ones(n,n_types);
        end

        pms(iter,:) = pm2;                  % estimated preference parameters
        fits(iter,4) = -fval;               % approximated log-likelihood M-Step ("Q-value")
        ind_lls(:,:,iter) = il;

        % Observed-data log-likelihood sum_i log(sum_g pi(g)*L(i,g)), also
        % evaluated in log space to avoid underflow.
        log_m = il + log(pi1);
        m_max = max(log_m,[],2);
        m_max(~isfinite(m_max)) = 0;
        tot_loglik = sum(m_max + log(sum(exp(log_m - m_max),2)));
        fits(iter,1) = tot_loglik;          % log-likelihood

        % Classification entropy -sum(tau.*log(tau)), with 0*log(0) taken as
        % 0. Otherwise a tau of exactly zero turns the entropy into NaN.
        ent = taus1 .* log(taus1);
        ent(taus1 == 0) = 0;
        fits(iter,3) = -sum(ent(:));        % entropy

        % (approximated) ICL criterion (see McLachlan et al 2019)
        % NOTE(review): the textbook ICL-BIC adds 2*entropy; here the entropy
        % enters with weight 1 -- confirm this is intended.
        fits(iter,2) = -2*(tot_loglik)+(n_types*n_est_params-1)*log(n)+fits(iter,3);

        if iter>1
            ll_diff = fits(iter,1)-fits(iter-1,1);
        end
    end

    % Store the iteration with the highest log-likelihood of this run
    % (max ignores NaN entries).
    if iter > 0
        [~,I]=max(fits(1:iter,1));
        fit_subresults(s,:) = fits(I,:);
        pm_subresults(s,:) = pms(I,:);
        pi_subresults(s,:) = pis(I,:);
        tau_subresults(:,:,s) = taus(:,:,I);
    end

end

% Final log-likelihood of every starting point and the index of the best one
% (max skips NaN entries).
logl2 = fit_subresults(:,1);
[M2, I2] = max(logl2);

% Results of the winning starting point.
tau_results = tau_subresults(:,:,I2);
pi_results  = pi_subresults(I2,:);
pm0_results = pm0_subresults(I2,:);
pm_results  = pm_subresults(I2,:);
fit_results = fit_subresults(I2,:);

% Parameters arranged as one column per type.
pm_matrix = reshape(pm_results, n_params, n_types);


%% Bootstrap standard errors ----------------------------------------------
% Per-replication storage (row r / page r = replication r). The SE vectors
% stay at zero when get_se ~= 1.
se_pis = zeros(1, n_types);
se_pms = zeros(1, n_params*n_types);
tau_subresults_bstrp = zeros(n, n_types, nr);
pi_subresults_bstrp  = zeros(nr, n_types);
pm0_subresults_bstrp = zeros(nr, n_types*n_params);
pm_subresults_bstrp  = zeros(nr, n_types*n_params);
fit_subresults_bstrp = zeros(nr, 4);
% start bootstrapping of standard errors
% Nonparametric bootstrap over subjects:
%  - each replication draws n subject ids with replacement;
%  - it stacks their no_trials-row blocks into a new data set;
%  - it re-runs EM starting from the point estimates pm_results.
% The standard errors are the standard deviations of the replicated
% estimates.
if get_se == 1
    btstrp_ids = zeros(n,nr);
    for z = 1:nr
        btstrp_ids(:,z) = randsample(n,n,true);
    end

    pm_start = pm_results;

    parfor s = 1:nr

        pis = zeros(n_iterations,n_types);
        pms = zeros(n_iterations,n_types*n_params);
        fits = zeros(n_iterations,4);
        taus = zeros(n,n_types,n_iterations);
        ind_lls = zeros(n,n_types,n_iterations);

        iter = 0;
        ll_diff = 100;

        % Block i of the resampled data is subject ids(i)'s block of trials.
        % Column i of src_rows lists the original row numbers of that block.
        ids = btstrp_ids(:,s);
        src_rows = (ids.' - 1)*no_trials + (1:no_trials).';
        btstrp_dt = dt(src_rows(:),:);

        while (ll_diff>ll_diff_stop && iter < n_iterations)

            iter = iter+1;

            if iter==1
                pm0 = pm_start;                          % start from the point estimates
                pi0 = (1/n_types)*ones(1,n_types);       % equal initial mixing weights (row vector)
                pm0_subresults_bstrp(s,:)=pm0;
                ind_ll0=lolik_mix_estep(pm0,btstrp_dt,setts,risk_neutral);
            else
                pm0 = pms(iter-1,:);
                pi0 = pis(iter-1,:);
                ind_ll0=ind_lls(:,:,iter-1);
            end

            % E-step in log space: tau(i,g) = pi(g)*L(i,g) / sum_h pi(h)*L(i,h).
            % Subtracting each row's maximum before exp() keeps the ratio
            % defined when exp(log-likelihood) underflows for every type.
            log_w = log(pi0) + ind_ll0;
            row_max = max(log_w,[],2);
            row_max(~isfinite(row_max)) = 0;
            w = exp(log_w - row_max);
            taus(:,:,iter) = w ./ sum(w,2);

            taus1=taus(:,:,iter);
            pi1=mean(taus1);
            pis(iter,:)=pi1;

            try
                f_unc=@(pm2)lolik_mix(pm2,btstrp_dt,setts,taus1,pi1,risk_neutral);
                [pm2,fval] = fmincon(f_unc,pm0,A_unc,b_unc,Aeq_unc,beq_unc);
                il=lolik_mix_estep(pm2,btstrp_dt,setts,risk_neutral);
            catch ME
                % Report the failure. The sentinel values give this iteration
                % a very low log-likelihood.
                warning('Problem using function: %s', ME.message);
                fval = 99999;
                pm2 = 99*ones(1,n_params*n_types);
                il = -99*ones(n,n_types);
            end

            pms(iter,:) = pm2;                  % estimated preference parameters
            fits(iter,4) = -fval;               % approximated log-likelihood M-Step ("Q-value")
            ind_lls(:,:,iter) = il;

            % Observed-data log-likelihood, evaluated in log space.
            log_m = il + log(pi1);
            m_max = max(log_m,[],2);
            m_max(~isfinite(m_max)) = 0;
            fits(iter,1) = sum(m_max + log(sum(exp(log_m - m_max),2)));

            if iter>1
                ll_diff = fits(iter,1)-fits(iter-1,1);
            end
        end

        % Keep the iteration with the highest log-likelihood of this replication.
        if iter > 0
            [~,I]=max(fits(1:iter,1));
            fit_subresults_bstrp(s,:) = fits(I,:);
            pm_subresults_bstrp(s,:) = pms(I,:);
            pi_subresults_bstrp(s,:) = pis(I,:);
            tau_subresults_bstrp(:,:,s) = taus(:,:,I);
        end

    end

    % Replications whose best estimate is the failure sentinel inflate the
    % standard errors; make them visible instead of silently including them.
    n_failed = sum(all(pm_subresults_bstrp == 99, 2));
    if n_failed > 0
        warning('%d of %d bootstrap replications returned the failure sentinel (99) and are included in the standard errors.', n_failed, nr);
    end

    % Column-wise standard deviations (dim 1 so a single replication still
    % yields one value per column).
    se_pms = std(pm_subresults_bstrp,0,1);
    se_pis = std(pi_subresults_bstrp,0,1);
end

pm_se_bstrp = reshape(se_pms,[], n_types);

% store the results
% Rows of estimates_output, top to bottom:
%  - parameter estimates (one column per type);
%  - mixing weights;
%  - bootstrap SEs of the parameters;
%  - bootstrap SEs of the mixing weights.
estimates_output = [pm_matrix;pi_results;pm_se_bstrp;se_pis];

% Make sure the output folders exist. Otherwise save/writetable fail after
% the (long) estimation has finished and the results are lost.
for out_sub = {'mat_files','taus','estimates','fits'}
    out_dir = fullfile('output','mixtures',out_sub{1});
    if ~isfolder(out_dir)
        mkdir(out_dir);
    end
end
clear out_sub out_dir

% save() with only a file name stores every workspace variable.
save(fullfile('output','mixtures','mat_files',estimation_name));

T1 = table(tau_results);
writetable(T1,fullfile('output','mixtures','taus',estimation_name));

T2 = table(estimates_output);
writetable(T2,fullfile('output','mixtures','estimates',estimation_name));

T3 = table(fit_results);
writetable(T3,fullfile('output','mixtures','fits',estimation_name));
